Matrix Chain Multiplication
Algorithm
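- Create the partition half-table of size $n$ (the number of matrices): row $i$ holds the cells $(i, j)$ for $j \ge i$. Initialize every cell with 0.
- The cost of multiplying two matrices of dimensions $(p\times q)$ and $(q\times r)$ is $p\times q\times r$ scalar multiplications.
- For $k = 1$ to $n - 1$:
  - For each index $i = 1$ to $n - k$:
    - Split the interval $(i, i + k)$ at every position $s$ with $i \le s < i + k$ into $(i, s)$ and $(s + 1, i + k)$. The cost of a split is the cost already stored for each part plus the cost of multiplying the two resulting matrices.
    - Store the minimum split cost in cell $(i, i + k)$.
- The value of cell $(1, n)$ is the answer; in the code below it is stored at `partition_table[0][0]`.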
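The steps above follow the usual matrix-chain recurrence. As a sketch in notation introduced only here, let matrix $A_i$ have dimensions $p_{i-1}\times p_i$ and let $C(i, j)$ be the minimum cost of computing the product $A_i \cdots A_j$:

$$
C(i, j) =
\begin{cases}
0, & i = j,\\
\displaystyle\min_{i \le s < j} \bigl( C(i, s) + C(s + 1, j) + p_{i-1}\, p_s\, p_j \bigr), & i < j.
\end{cases}
$$

The answer is $C(1, n)$. For instance, with the dimensions of Driver code 1 ($p = (10, 20, 30, 40, 50)$), $C(1, 3) = \min(0 + 24000 + 10\cdot 20\cdot 40,\; 6000 + 0 + 10\cdot 30\cdot 40) = \min(32000, 18000) = 18000$, which matches the value filled in at Step 2.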
Function to create splits
from typing import List, Tuple

# Split the interval (i, j) at every k with i <= k < j
# into a left part (i, k) and a right part (k + 1, j).
def split_interval(i: int, j: int, verbose: bool = False) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
    if verbose:
        print(f"\nSplitting ({i}, {j})")
    splits = []
    for k in range(i, j):
        split = ((i, k), (k + 1, j))
        splits.append(split)
    return splits
Function to get cost of multiplying 2 matrices
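# Number of scalar multiplications needed to multiply a (p x q) matrix
# by a (q x r) matrix: p * q * r.
def getCost(m1: Tuple[int, int], m2: Tuple[int, int]) -> int:
    if m1[1] != m2[0]:
        raise ValueError("Invalid matrix dimensions")
    return m1[0] * m1[1] * m2[1]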
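Because the cost rule depends on the shapes involved, the order in which a chain is multiplied changes the total. The sketch below multiplies the first three matrices of Driver code 1 in both possible orders; `product_cost` is a hypothetical helper written for this illustration (unlike `getCost`, it also returns the shape of the product).

```python
from typing import Tuple

Dims = Tuple[int, int]

def product_cost(a: Dims, b: Dims) -> Tuple[Dims, int]:
    """Multiply an (m x n) by an (n x p) matrix; return the result's shape and m * n * p."""
    if a[1] != b[0]:
        raise ValueError("Invalid matrix dimensions")
    return (a[0], b[1]), a[0] * a[1] * b[1]

A1, A2, A3 = (10, 20), (20, 30), (30, 40)

# Order 1: (A1 A2) A3
shape12, c12 = product_cost(A1, A2)      # shape (10, 30), cost 6000
_, c_left = product_cost(shape12, A3)    # cost 10 * 30 * 40 = 12000
left_first = c12 + c_left

# Order 2: A1 (A2 A3)
shape23, c23 = product_cost(A2, A3)      # shape (20, 40), cost 24000
_, c_right = product_cost(A1, shape23)   # cost 10 * 20 * 40 = 8000
right_first = c23 + c_right

print(left_first, right_first)  # 18000 32000
```

The two totals are exactly the two split costs reported for $(1, 3)$ in Step 2 of Driver code 1.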
from typing import List
from IPython.display import display, HTML

# Render the half-table as an HTML table. Columns are labelled n..1 and rows 1..n,
# so the cell in row i, column j shows the minimum cost of multiplying matrices i..j.
def displayTable(arr: List[List[int]]):
    n = len(arr[0])
    header_cells = "".join(f"<td>{n - i}</td>" for i in range(n))
    headers = f"<thead><tr><td> </td>{header_cells}</tr></thead>"
    htmlData = "<tbody>"
    for i in range(n):
        data = f"<tr><td>{i + 1}</td>"
        for cell in arr[i]:
            # show empty cells as a blank
            t = " " if cell == '' else cell
            data += f"<td>{t}</td>"
        data += "</tr>"
        htmlData += data
    htmlData += "</tbody>"
    display(HTML(f"<table>{headers}{htmlData}</table><hr/>"))
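`displayTable` relies on IPython, so it only renders inside Jupyter. As a hedged alternative for a plain terminal, a hypothetical `format_table` helper (not part of the notebook) could lay out the same half-table as text, with the same row labels ($1 \ldots n$) and column labels ($n \ldots 1$):

```python
from typing import List

def format_table(arr: List[List[int]]) -> str:
    """Render the half-table as right-aligned plain text with displayTable's labels."""
    n = len(arr[0])
    # One common column width, wide enough for every value and every label.
    width = max(max(len(str(cell)) for row in arr for cell in row), len(str(n)))
    header = " " * width + " | " + " | ".join(str(n - i).rjust(width) for i in range(n))
    lines = [header]
    for i, row in enumerate(arr):
        cells = " | ".join(str(cell).rjust(width) for cell in row)
        lines.append(str(i + 1).rjust(width) + " | " + cells)
    return "\n".join(lines)

print(format_table([[0, 0, 6000, 0], [0, 24000, 0], [60000, 0], [0]]))
```

After running a solver, `print(format_table(solver.partition_table))` would print the current table in the same layout.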
inf = float("inf")

class PartitionTable:
    def __init__(self, matrices_: List[Tuple[int, int]]):
        self.matrices = matrices_[:]
        self.table_size = len(matrices_)
        # Half-table: row i (0-based) has table_size - i cells, all initialized to 0.
        self.partition_table = [[0] * (self.table_size - i) for i in range(self.table_size)]
        # Map the 1-based interval (x, y) to its (row, column) position in the half-table.
        self.map_index = lambda x, y: (x - 1, self.table_size - y)

    # setter function
    def setValue(self, x_: int, y_: int, val_: int):
        x, y = self.map_index(x_, y_)
        self.partition_table[x][y] = val_

    # getter function
    def getValue(self, x_: int, y_: int) -> int:
        x, y = self.map_index(x_, y_)
        return self.partition_table[x][y]

    # Cost of a split ((i, k), (k + 1, j)): the stored cost of each part plus
    # the cost of multiplying the two resulting matrices.
    def getCompositeValue(self, partition: Tuple[Tuple[int, int], Tuple[int, int]]) -> int:
        cell1, cell2 = partition
        # The product of matrices a..b has shape (rows of a) x (columns of b).
        matrix1 = (self.matrices[cell1[0] - 1][0], self.matrices[cell1[1] - 1][1])
        matrix2 = (self.matrices[cell2[0] - 1][0], self.matrices[cell2[1] - 1][1])
        cost = getCost(matrix1, matrix2)
        value1 = self.getValue(cell1[0], cell1[1])
        value2 = self.getValue(cell2[0], cell2[1])
        return value1 + value2 + cost

    # Fill cell (x_, y_) with the minimum cost over all splits of the interval.
    def createValue(self, x_: int, y_: int, verbose: bool = False):
        # Skip invalid intervals and cells that are already filled.
        if x_ > y_ or self.getValue(x_, y_) != 0:
            return
        # A single matrix needs no multiplication; the cell keeps its initial 0.
        if x_ == y_:
            return
        currentValue = inf
        for split in split_interval(x_, y_, verbose):
            value = self.getCompositeValue(split)
            if verbose:
                print(f"Cost of split {split[0]} - {split[1]} = {value}")
            if value < currentValue:
                currentValue = value
        if verbose:
            print(f"Min. cost of ({x_}, {y_}) = {currentValue}")
        self.setValue(x_, y_, currentValue)

    def __call__(self, verbose: bool = False) -> int:
        # k is the distance between the interval's endpoints (chain length - 1).
        for k in range(1, self.table_size):
            if verbose:
                display(HTML(f"<h2>Step:- {k}</h2>"))
            for start in range(1, self.table_size - k + 1):
                self.createValue(start, start + k, verbose)
            if verbose:
                displayTable(self.partition_table)
        # partition_table[0][0] is cell (1, n): the minimum cost of the whole chain.
        return self.partition_table[0][0]
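The class stores only the upper half of the table and maps intervals to cells through `map_index`. A more compact sketch of the same bottom-up loop, assuming 0-based indices and a full $n\times n$ list of lists (the part below the diagonal is unused), trades that memory for direct `cost[i][j]` indexing. `matrix_chain_cost` is a name chosen here, not part of the notebook.

```python
from typing import List, Tuple

def matrix_chain_cost(matrices: List[Tuple[int, int]]) -> int:
    """Minimum number of scalar multiplications for the chain, via a bottom-up DP table."""
    for a, b in zip(matrices, matrices[1:]):
        if a[1] != b[0]:
            raise ValueError("Invalid matrix dimensions")
    n = len(matrices)
    # Matrix i (0-based) has shape p[i] x p[i + 1].
    p = [m[0] for m in matrices] + [matrices[-1][1]]
    # cost[i][j] = minimum cost of multiplying matrices i..j; the diagonal stays 0.
    cost = [[0] * n for _ in range(n)]
    for k in range(1, n):          # k = distance between the interval's endpoints
        for i in range(n - k):
            j = i + k
            cost[i][j] = min(
                cost[i][s] + cost[s + 1][j] + p[i] * p[s + 1] * p[j + 1]
                for s in range(i, j)
            )
    return cost[0][n - 1]

print(matrix_chain_cost([(10, 20), (20, 30), (30, 40), (40, 50)]))  # 38000
```

For the three driver examples below it returns 38000, 405 and 160, the same answers as `PartitionTable`.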
Driver code 1
matrices = [(10, 20), (20, 30), (30, 40), (40, 50)]
solver = PartitionTable(matrices)
print(f"Answer = {solver(verbose=True)}")
Step:- 1

Splitting (1, 2)
Cost of split (1, 1) - (2, 2) = 6000
Min. cost of (1, 2) = 6000

Splitting (2, 3)
Cost of split (2, 2) - (3, 3) = 24000
Min. cost of (2, 3) = 24000

Splitting (3, 4)
Cost of split (3, 3) - (4, 4) = 60000
Min. cost of (3, 4) = 60000

|   | 4 | 3 | 2 | 1 |
|---|---|---|---|---|
| 1 | 0 | 0 | 6000 | 0 |
| 2 | 0 | 24000 | 0 |   |
| 3 | 60000 | 0 |   |   |
| 4 | 0 |   |   |   |
Step:- 2

Splitting (1, 3)
Cost of split (1, 1) - (2, 3) = 32000
Cost of split (1, 2) - (3, 3) = 18000
Min. cost of (1, 3) = 18000

Splitting (2, 4)
Cost of split (2, 2) - (3, 4) = 90000
Cost of split (2, 3) - (4, 4) = 64000
Min. cost of (2, 4) = 64000

|   | 4 | 3 | 2 | 1 |
|---|---|---|---|---|
| 1 | 0 | 18000 | 6000 | 0 |
| 2 | 64000 | 24000 | 0 |   |
| 3 | 60000 | 0 |   |   |
| 4 | 0 |   |   |   |
Step:- 3

Splitting (1, 4)
Cost of split (1, 1) - (2, 4) = 74000
Cost of split (1, 2) - (3, 4) = 81000
Cost of split (1, 3) - (4, 4) = 38000
Min. cost of (1, 4) = 38000

|   | 4 | 3 | 2 | 1 |
|---|---|---|---|---|
| 1 | 38000 | 18000 | 6000 | 0 |
| 2 | 64000 | 24000 | 0 |   |
| 3 | 60000 | 0 |   |   |
| 4 | 0 |   |   |   |
Answer = 38000
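As an independent check on the DP answer, an exhaustive search can enumerate every full parenthesization of the chain. It takes exponential time, so it is only sensible for short chains; `all_costs` is a hypothetical helper written for this check.

```python
from typing import List, Tuple

def all_costs(dims: List[Tuple[int, int]]) -> List[int]:
    """Cost of every full parenthesization of the chain (exhaustive, exponential time)."""
    if len(dims) == 1:
        return [0]
    results = []
    for s in range(1, len(dims)):
        left, right = dims[:s], dims[s:]
        # Left product is (rows of first) x (cols of its last); right product ends in right[-1][1].
        join = left[0][0] * left[-1][1] * right[-1][1]
        for a in all_costs(left):
            for b in all_costs(right):
                results.append(a + b + join)
    return results

costs = all_costs([(10, 20), (20, 30), (30, 40), (40, 50)])
print(len(costs), min(costs))  # 5 38000
```

Four matrices admit only 5 parenthesizations (a Catalan number), and the cheapest costs 38000, agreeing with the table. The count grows exponentially with the chain length, which is why the table-based approach is used instead.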
Driver code 2
matrices = [(5, 10), (10, 3), (3, 12), (12, 5)]
solver = PartitionTable(matrices)
print(f"Answer = {solver(verbose=True)}")
Step:- 1

Splitting (1, 2)
Cost of split (1, 1) - (2, 2) = 150
Min. cost of (1, 2) = 150

Splitting (2, 3)
Cost of split (2, 2) - (3, 3) = 360
Min. cost of (2, 3) = 360

Splitting (3, 4)
Cost of split (3, 3) - (4, 4) = 180
Min. cost of (3, 4) = 180

|   | 4 | 3 | 2 | 1 |
|---|---|---|---|---|
| 1 | 0 | 0 | 150 | 0 |
| 2 | 0 | 360 | 0 |   |
| 3 | 180 | 0 |   |   |
| 4 | 0 |   |   |   |
Step:- 2

Splitting (1, 3)
Cost of split (1, 1) - (2, 3) = 960
Cost of split (1, 2) - (3, 3) = 330
Min. cost of (1, 3) = 330

Splitting (2, 4)
Cost of split (2, 2) - (3, 4) = 330
Cost of split (2, 3) - (4, 4) = 960
Min. cost of (2, 4) = 330

|   | 4 | 3 | 2 | 1 |
|---|---|---|---|---|
| 1 | 0 | 330 | 150 | 0 |
| 2 | 330 | 360 | 0 |   |
| 3 | 180 | 0 |   |   |
| 4 | 0 |   |   |   |
Step:- 3

Splitting (1, 4)
Cost of split (1, 1) - (2, 4) = 580
Cost of split (1, 2) - (3, 4) = 405
Cost of split (1, 3) - (4, 4) = 630
Min. cost of (1, 4) = 405

|   | 4 | 3 | 2 | 1 |
|---|---|---|---|---|
| 1 | 405 | 330 | 150 | 0 |
| 2 | 330 | 360 | 0 |   |
| 3 | 180 | 0 |   |   |
| 4 | 0 |   |   |   |
Answer = 405
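The notebook records only costs, so it does not say *which* parenthesization achieves 405. A common extension, sketched here with 0-based indices and ties broken toward the smaller split index, also stores the best split point for each interval and rebuilds the order recursively. `optimal_order` is a hypothetical name, not part of the notebook.

```python
from typing import List, Tuple

def optimal_order(matrices: List[Tuple[int, int]]) -> Tuple[int, str]:
    """Return the minimum cost and one optimal parenthesization, e.g. '((A1A2)(A3A4))'."""
    n = len(matrices)
    p = [m[0] for m in matrices] + [matrices[-1][1]]   # matrix i has shape p[i] x p[i + 1]
    cost = [[0] * n for _ in range(n)]
    split = [[0] * n for _ in range(n)]                 # split[i][j] = best s for interval i..j
    for k in range(1, n):
        for i in range(n - k):
            j = i + k
            # min over (cost, s) pairs keeps the smallest s on equal costs
            cost[i][j], split[i][j] = min(
                (cost[i][s] + cost[s + 1][j] + p[i] * p[s + 1] * p[j + 1], s)
                for s in range(i, j)
            )

    def build(i: int, j: int) -> str:
        if i == j:
            return f"A{i + 1}"
        s = split[i][j]
        return f"({build(i, s)}{build(s + 1, j)})"

    return cost[0][n - 1], build(0, n - 1)

print(optimal_order([(5, 10), (10, 3), (3, 12), (12, 5)]))  # (405, '((A1A2)(A3A4))')
```

For the Driver code 3 matrices the same function gives `(160, '((A1(A2A3))(A4A5))')`.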
Driver code 3
matrices = [(5, 4), (4, 6), (6, 2), (2, 7), (7, 3)]
solver = PartitionTable(matrices)
print(f"Answer = {solver(verbose=True)}")
Step:- 1

Splitting (1, 2)
Cost of split (1, 1) - (2, 2) = 120
Min. cost of (1, 2) = 120

Splitting (2, 3)
Cost of split (2, 2) - (3, 3) = 48
Min. cost of (2, 3) = 48

Splitting (3, 4)
Cost of split (3, 3) - (4, 4) = 84
Min. cost of (3, 4) = 84

Splitting (4, 5)
Cost of split (4, 4) - (5, 5) = 42
Min. cost of (4, 5) = 42

|   | 5 | 4 | 3 | 2 | 1 |
|---|---|---|---|---|---|
| 1 | 0 | 0 | 0 | 120 | 0 |
| 2 | 0 | 0 | 48 | 0 |   |
| 3 | 0 | 84 | 0 |   |   |
| 4 | 42 | 0 |   |   |   |
| 5 | 0 |   |   |   |   |
Step:- 2

Splitting (1, 3)
Cost of split (1, 1) - (2, 3) = 88
Cost of split (1, 2) - (3, 3) = 180
Min. cost of (1, 3) = 88

Splitting (2, 4)
Cost of split (2, 2) - (3, 4) = 252
Cost of split (2, 3) - (4, 4) = 104
Min. cost of (2, 4) = 104

Splitting (3, 5)
Cost of split (3, 3) - (4, 5) = 78
Cost of split (3, 4) - (5, 5) = 210
Min. cost of (3, 5) = 78

|   | 5 | 4 | 3 | 2 | 1 |
|---|---|---|---|---|---|
| 1 | 0 | 0 | 88 | 120 | 0 |
| 2 | 0 | 104 | 48 | 0 |   |
| 3 | 78 | 84 | 0 |   |   |
| 4 | 42 | 0 |   |   |   |
| 5 | 0 |   |   |   |   |
Step:- 3

Splitting (1, 4)
Cost of split (1, 1) - (2, 4) = 244
Cost of split (1, 2) - (3, 4) = 414
Cost of split (1, 3) - (4, 4) = 158
Min. cost of (1, 4) = 158

Splitting (2, 5)
Cost of split (2, 2) - (3, 5) = 150
Cost of split (2, 3) - (4, 5) = 114
Cost of split (2, 4) - (5, 5) = 188
Min. cost of (2, 5) = 114

|   | 5 | 4 | 3 | 2 | 1 |
|---|---|---|---|---|---|
| 1 | 0 | 158 | 88 | 120 | 0 |
| 2 | 114 | 104 | 48 | 0 |   |
| 3 | 78 | 84 | 0 |   |   |
| 4 | 42 | 0 |   |   |   |
| 5 | 0 |   |   |   |   |
Step:- 4

Splitting (1, 5)
Cost of split (1, 1) - (2, 5) = 174
Cost of split (1, 2) - (3, 5) = 288
Cost of split (1, 3) - (4, 5) = 160
Cost of split (1, 4) - (5, 5) = 263
Min. cost of (1, 5) = 160

|   | 5 | 4 | 3 | 2 | 1 |
|---|---|---|---|---|---|
| 1 | 160 | 158 | 88 | 120 | 0 |
| 2 | 114 | 104 | 48 | 0 |   |
| 3 | 78 | 84 | 0 |   |   |
| 4 | 42 | 0 |   |   |   |
| 5 | 0 |   |   |   |   |
Answer = 160
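The same recurrence can also be evaluated top-down: recursion with memoization via `functools.lru_cache` computes each interval once, just as the table does. A minimal sketch on the Driver code 3 matrices, using 0-based indices (`min_cost_top_down` is a hypothetical name):

```python
from functools import lru_cache
from typing import List, Tuple

def min_cost_top_down(matrices: List[Tuple[int, int]]) -> int:
    """Same recurrence as the table, evaluated recursively with memoization."""
    p = [m[0] for m in matrices] + [matrices[-1][1]]   # matrix i has shape p[i] x p[i + 1]

    @lru_cache(maxsize=None)
    def cost(i: int, j: int) -> int:
        # Minimum cost of multiplying matrices i..j (inclusive).
        if i == j:
            return 0
        return min(cost(i, s) + cost(s + 1, j) + p[i] * p[s + 1] * p[j + 1]
                   for s in range(i, j))

    return cost(0, len(matrices) - 1)

print(min_cost_top_down([(5, 4), (4, 6), (6, 2), (2, 7), (7, 3)]))  # 160
```

The bottom-up table used in the notebook avoids deep recursion; for very long chains the recursive version may hit Python's default recursion limit.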